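""" Estimate and save the first-stage policy functions.

For each time period in `time_periods`, build the polynomial state datasets,
fit the OLS and MNL policy functions, and write their coefficients, the
average rig counts, and the contract-extension estimates to
./models/first_stage/.
"""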
from sklearn.linear_model import LinearRegression as Ols
from sklearn.linear_model import LogisticRegression as Logit

import numpy as np
import pandas as pd
import json
import sys
sys.path.append('./')

from src.models_new import first_stage_utils

#%% CONFIGURE ---------------------------------------------------------------------------
use_previous_extensions: bool = True   # copy the stored extension estimates instead of refitting them
polynomial_order: int = 2              # polynomial order passed to first_stage_utils.setup_data
time_periods = ['month', 'fortnight']  # one full estimation pass (and one set of output files) per entry

# The contracts file is the same for every time period. Read it once and give
# each iteration a fresh copy, because the fortnight branch overwrites columns.
contracts_raw = pd.read_csv('./data_py/processed/contracts_final.csv')

for t in time_periods:
    #%% READ IN THE DATA ----------------------------------------------------------------
    state = pd.read_csv(f'./data_py/processed/states_{t}.csv')
    mnl = pd.read_csv(f'./data_py/processed/mnl_{t}.csv')
    contracts = contracts_raw.copy()

    #%% CONSTRUCT DATASET WITH POLYNOMIALS OF STATES ------------------------------------
    # For fortnightly periods, replace the state columns with their *_fortnight
    # counterparts so the downstream setup uses the fortnight states.
    if t == 'fortnight':
        for col in ['g', 'n_l', 'n_m', 'n_h']:
            contracts[col] = contracts[f'{col}_fortnight']
    (
        contracts_poly, mnl_poly, poly_names, poly
    ) = first_stage_utils.setup_data(state, mnl, contracts, polynomial_order)

    #%% SAVE PRELIMINARY DATASETS -------------------------------------------------------
    contracts_poly.to_csv(f'./models/first_stage/contracts_poly_{t}.csv')
    mnl_poly.to_csv(f'./models/first_stage/mnl_poly_{t}.csv')

    #%% GET THE POLICY FUNCTIONS --------------------------------------------------------
    output_mnl, reg_mnl = first_stage_utils.estimate_mnl(mnl_poly, poly_names, poly)
    output_ols, reg_ols = first_stage_utils.estimate_ols(contracts_poly, poly_names, poly)

    #%% SAVE THE OLS POLICY FUNCTIONS ---------------------------------------------------
    # One row per (spec, metric): the intercept followed by the fitted coefficients.
    results_all_ols = {
        (spec, metric): np.append(
            np.array(reg_ols[f'{metric} {spec}'].intercept_),
            reg_ols[f'{metric} {spec}'].coef_,
        )
        for spec in ['low', 'mid', 'high']
        for metric in ['mri', 'value', 'day_rate']
    }

    results_all_ols_df = pd.DataFrame(
        results_all_ols,
        index=['intercept'] + poly_names + ['tau']).T

    results_all_ols_df.to_csv(f'./models/first_stage/first_stage_ols_{t}.csv')

    #%% SAVE THE MNL POLICY FUNCTIONS ---------------------------------------------------
    # Output labels for rows 0..3 of each fitted MNL's coef_/intercept_.
    # NOTE(review): hard-coded; assumes the estimator's class order is 0, 2, 3, 4
    # (label 1 is not present) -- TODO confirm against reg.classes_.
    mnl_choices = (0, 2, 3, 4)
    results_all_mnl_coefs = dict()
    results_all_mnl_intercept = dict()
    for spec in ['low', 'mid', 'high']:
        reg = reg_mnl[f'tau {spec}']
        for row, choice in enumerate(mnl_choices):
            results_all_mnl_coefs[(spec, choice)] = reg.coef_[row, :]
            results_all_mnl_intercept[(spec, choice)] = reg.intercept_[row]

    # Get the policy function coefficients into a nice form
    results_all_mnl_coefs_df = pd.DataFrame(
        results_all_mnl_coefs,
        index=poly_names).T

    results_all_mnl_intercept_df = pd.DataFrame(
        results_all_mnl_intercept,
        index=['intercept']).T

    results_all_mnl_df = pd.concat(
        [results_all_mnl_coefs_df, results_all_mnl_intercept_df], axis=1)
    results_all_mnl_df = results_all_mnl_df[['intercept'] + poly_names]

    # Finally, save them
    results_all_mnl_df.to_csv(f'./models/first_stage/first_stage_mnl_{t}.csv')

    #%% GET N RIGS ----------------------------------------------------------------------
    # Average number of rigs per spec over all periods, rounded to an integer.
    rename_rows = {'n_total_low': 'low', 'n_total_mid': 'mid', 'n_total_high': 'high'}
    n_rigs_series = (
        state[['n_total_low', 'n_total_mid', 'n_total_high']]
        .mean()
        .round(0)
        .astype('int')
        .rename(index=rename_rows)
    )
    # Cast explicitly: on older pandas, to_dict() yields numpy ints, which
    # json.dump cannot serialize.
    n_rigs = {spec: int(count) for spec, count in n_rigs_series.items()}
    with open(f'./models/first_stage/n_rigs_{t}.json', 'w') as fp:
        json.dump(n_rigs, fp, indent=4)

    #%% GET EXTENSIONS ------------------------------------------------------------------
    if use_previous_extensions:
        # Reuse stored estimates: re-fitting gives very tiny numerical differences
        # across Intel vs ARM vs AMD machines. The files are not spec-specific,
        # so they are copied exactly once.
        df_extension = pd.read_csv(
            f'./src/models_new/extensions_numerics/first_stage_extensions_{t}.csv',
            index_col=[0, 1])
        df_extension_with_mri = pd.read_csv(
            f'./src/models_new/extensions_numerics/first_stage_extensions_with_mri_{t}.csv',
            index_col=[0])
    else:
        # Order each rig's contracts chronologically. The "previous row" mask
        # below relies on this ordering.
        contracts_poly.sort_values(['rig_name', 'fixture_date'], inplace=True)
        extended = contracts_poly[contracts_poly['reneg'].eq(True)]
        # True for rows whose preceding row belongs to the same rig, i.e. every
        # contract except a rig's first one.
        same_rig_as_prev = contracts_poly['rig_name'] == contracts_poly['rig_name'].shift(1)

        reg_extension_by_metric_spec = dict()
        reg_extension_with_mri_by_spec = dict()
        for spec in ['low', 'mid', 'high']:
            # Extension price: OLS of day_rate on the state polynomial and tau,
            # fitted on renegotiated contracts only.
            extended_spec = extended[extended['spec'] == spec]
            reg_extension_price = Ols().fit(
                extended_spec[poly_names + ['tau']],
                extended_spec['day_rate']
            )
            reg_extension_by_metric_spec[(spec, 'price')] = np.append(
                np.array(reg_extension_price.intercept_),
                reg_extension_price.coef_
            )

            # Extension probability: logit of reneg on non-initial contracts,
            # fitted without and with mri as an extra regressor.
            not_initial_by_spec = contracts_poly[
                same_rig_as_prev & (contracts_poly['spec'] == spec)
            ]
            reg_extension_prob = Logit(max_iter=10000).fit(
                not_initial_by_spec[poly_names + ['tau']],
                not_initial_by_spec['reneg']
            )
            reg_extension_prob_with_mri = Logit(max_iter=10000).fit(
                not_initial_by_spec[poly_names + ['tau', 'mri']],
                not_initial_by_spec['reneg']
            )
            reg_extension_by_metric_spec[(spec, 'prob')] = np.append(
                np.array(reg_extension_prob.intercept_),
                reg_extension_prob.coef_
            )
            reg_extension_with_mri_by_spec[spec] = np.append(
                np.array(reg_extension_prob_with_mri.intercept_),
                reg_extension_prob_with_mri.coef_
            )

        # Collect coefficients: one row per (spec, price|prob) / per spec.
        df_extension = pd.DataFrame(
            reg_extension_by_metric_spec,
            index=['intercept'] + poly_names + ['tau']).T
        df_extension_with_mri = pd.DataFrame(
            reg_extension_with_mri_by_spec,
            index=['intercept'] + poly_names + ['tau', 'mri']).T

    # Finally, save them (same destination for stored and freshly fitted estimates)
    df_extension.to_csv(f'./models/first_stage/first_stage_extensions_{t}.csv')
    df_extension_with_mri.to_csv(f'./models/first_stage/first_stage_extensions_with_mri_{t}.csv')
